{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolutional Neural Networks for Sentence Classification using Ignite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a tutorial on using Ignite to train neural network models, setup experiments and validate models.\n",
    "\n",
    "In this experiment, we'll be replicating [\n",
    "Convolutional Neural Networks for Sentence Classification by Yoon Kim](https://arxiv.org/abs/1408.5882)! This paper uses CNN for text classification, a task typically reserved for RNNs, Logistic Regression, Naive Bayes.\n",
    "\n",
    "We want to be able to classify IMDB movie reviews and predict whether the review is positive or negative. IMDB Movie Review dataset comprises of 25000 positive and 25000 negative examples. The dataset comprises of text and label pairs. This is binary classification problem. We'll be using PyTorch to create the model, torchtext to import data and Ignite to train and monitor the models!\n",
    "\n",
    "Lets get started! "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Required Dependencies \n",
    "\n",
    "In this example we only need torchvision package, assuming that `torch` and `ignite` are already installed. We can install it using `pip`:\n",
    "\n",
    "`pip install torchtext spacy`\n",
    "\n",
    "`python -m spacy download en`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torchtext` is a library that provides multiple datasets for NLP tasks, similar to `torchvision`. Below we import the following:\n",
    "* **data**: A module to setup the data in the form Fields and Labels.\n",
    "* **datasets**: A module to download NLP datasets.\n",
    "* **GloVe**: A module to download and use pretrained GloVe embedings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchtext import data\n",
    "from torchtext import datasets\n",
    "from torchtext.vocab import GloVe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We import torch, nn and functional modules to create our models! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "SEED = 1234\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Ignite` is a High-level library to help with training neural networks in PyTorch. It comes with an `Engine` to setup a training loop, various metrics, handlers and a helpful contrib section! \n",
    "\n",
    "Below we import the following:\n",
    "* **Engine**: Runs a given process_function over each batch of a dataset, emitting events as it goes.\n",
    "* **Events**: Allows users to attach functions to an `Engine` to fire functions at a specific event. Eg: `EPOCH_COMPLETED`, `ITERATION_STARTED`, etc.\n",
    "* **Accuracy**: Metric to calculate accuracy over a dataset, for binary, multiclass, multilabel cases. \n",
    "* **Loss**: General metric that takes a loss function as a parameter, calculate loss over a dataset.\n",
    "* **RunningAverage**: General metric to attach to Engine during training. \n",
    "* **ModelCheckpoint**: Handler to checkpoint models. \n",
    "* **EarlyStopping**: Handler to stop training based on a score function. \n",
    "* **ProgressBar**: Handler to create a tqdm progress bar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ignite.engine import Engine, Events\n",
    "from ignite.metrics import Accuracy, Loss, RunningAverage\n",
    "from ignite.handlers import ModelCheckpoint, EarlyStopping\n",
    "from ignite.contrib.handlers import ProgressBar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Processing Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code below first sets up `TEXT` and `LABEL` as general data objects. \n",
    "\n",
    "* `TEXT` converts any text to lowercase and produces tensors with the batch dimension first. \n",
    "* `LABEL` is a data object that will convert any labels to floats.\n",
    "\n",
    "Next IMDB training and test datasets are downloaded, the training data is split into training and validation datasets. It takes TEXT and LABEL as inputs so that the data is processed as specified. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "TEXT = data.Field(lower=True, batch_first=True)\n",
    "LABEL = data.LabelField(dtype=torch.float)\n",
    "\n",
    "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL, root='/tmp/imdb/')\n",
    "train_data, valid_data = train_data.split(split_ratio=0.8, random_state=random.seed(SEED))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have three sets of the data - train, validation, test. Let's explore what these data objects are and how to extract data from them.\n",
    "* `train_data` is **torchtext.data.dataset.Dataset**, this is similar to **torch.utils.data.Dataset**.\n",
    "* `train_data[0]` is **torchtext.data.example.Example**, a Dataset is comprised of many Examples.\n",
    "* Let's explore the attributes of an example. We see a few methods, but most importantly we see `label` and `text`.\n",
    "* `example.text` is the text of that example and `example.label` is the label of the example. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "type(train_data) :  <class 'torchtext.data.dataset.Dataset'>\n",
      "type(train_data[0]) :  <class 'torchtext.data.example.Example'>\n",
      "Attributes of Example :  ['fromCSV', 'fromJSON', 'fromdict', 'fromlist', 'fromtree', 'label', 'text']\n",
      "example.label :  neg\n",
      "example.text[:10] :  ['i', 'sat', 'glued', 'to', 'the', 'screen,', 'riveted,', 'yawning,', 'yet', 'keeping']\n"
     ]
    }
   ],
   "source": [
    "print('type(train_data) : ', type(train_data))\n",
    "print('type(train_data[0]) : ',type(train_data[0]))\n",
    "example = train_data[0]\n",
    "print('Attributes of Example : ', [attr for attr in dir(example) if '_' not in attr])\n",
    "print('example.label : ', example.label)\n",
    "print('example.text[:10] : ', example.text[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have an idea of what are split datasets look like, lets dig further into `TEXT` and `LABEL`. It is important that we build our vocabulary based on the train dataset as validation and test are **unseen** in our experimenting. \n",
    "\n",
    "For `TEXT`, let's download the pretrained **GloVE** 100 dimensional word vectors. This means each word is described by 100 floats! If you want to read more about this, here are a few resources.\n",
    "* [StanfordNLP - GloVe](https://github.com/stanfordnlp/GloVe)\n",
    "* [DeepLearning.ai Lecture](https://www.coursera.org/lecture/nlp-sequence-models/glove-word-vectors-IxDTG)\n",
    "* [Stanford CS224N Lecture by Richard Socher](https://www.youtube.com/watch?v=ASn7ExxLZws)\n",
    "\n",
    "We use `TEXT.build_vocab` to do this, let's explore the `TEXT` object more. \n",
    "\n",
    "Let's explore what `TEXT` object is and how to extract data from them. We see `TEXT` has a few attributes, let's explore vocab, since we just used the build_vocab function. \n",
    "\n",
    "`TEXT.vocab` has the following attributes:\n",
    "* `extend` is used to extend the vocabulary\n",
    "* `freqs` is a dictionary of the frequency of each word\n",
    "* `itos` is a list of all the words in the vocabulary.\n",
    "* `stoi` is a dictionary mapping every word to an index.\n",
    "* `vectors` is a torch.Tensor of the downloaded embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attributes of TEXT :  ['dtype', 'dtypes', 'lower', 'numericalize', 'pad', 'postprocessing', 'preprocess', 'preprocessing', 'process', 'sequential', 'tokenize', 'vocab']\n",
      "Attributes of TEXT.vocab :  ['extend', 'freqs', 'itos', 'stoi', 'vectors']\n",
      "First 5 values TEXT.vocab.itos :  ['<unk>', '<pad>', 'the', 'a', 'and']\n",
      "First 5 key, value pairs of TEXT.vocab.stoi :  {'derived)': 134015, 'blum,': 120130, 'cartland': 50785, 'eccentric': 4902, 'designs': 6221}\n",
      "Shape of TEXT.vocab.vectors.shape :  torch.Size([218957, 100])\n"
     ]
    }
   ],
   "source": [
    "TEXT.build_vocab(train_data, vectors=GloVe(name='6B', dim=100, cache='/tmp/glove/'))\n",
    "print ('Attributes of TEXT : ', [attr for attr in dir(TEXT) if '_' not in attr])\n",
    "print ('Attributes of TEXT.vocab : ', [attr for attr in dir(TEXT.vocab) if '_' not in attr])\n",
    "print ('First 5 values TEXT.vocab.itos : ', TEXT.vocab.itos[0:5]) \n",
    "print ('First 5 key, value pairs of TEXT.vocab.stoi : ', {key:value for key,value in list(TEXT.vocab.stoi.items())[0:5]}) \n",
    "print ('Shape of TEXT.vocab.vectors.shape : ', TEXT.vocab.vectors.shape)"
   ]
  },
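  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `itos` and `stoi` concrete, here is a minimal sketch in plain Python of how such a vocabulary could be built and used to turn tokens into indices. It is an illustration only, not torchtext's implementation: the helpers `build_toy_vocab` and `numericalize` and the toy sentences are hypothetical, and `<unk>` and `<pad>` are placed at indices 0 and 1 to mirror the `itos` output above.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def build_toy_vocab(tokenized_texts, specials=('<unk>', '<pad>')):\n",
    "    # Count every token across the (training) texts.\n",
    "    freqs = Counter(tok for text in tokenized_texts for tok in text)\n",
    "    # Specials first, then words by descending frequency (ties alphabetical).\n",
    "    words = sorted(freqs, key=lambda w: (-freqs[w], w))\n",
    "    itos = list(specials) + words\n",
    "    stoi = {word: idx for idx, word in enumerate(itos)}\n",
    "    return itos, stoi\n",
    "\n",
    "def numericalize(tokens, stoi):\n",
    "    # Words missing from the vocabulary map to the index of '<unk>'.\n",
    "    return [stoi.get(tok, stoi['<unk>']) for tok in tokens]\n",
    "\n",
    "toy_train = [['the', 'movie', 'was', 'great'], ['the', 'plot', 'was', 'thin']]\n",
    "toy_itos, toy_stoi = build_toy_vocab(toy_train)\n",
    "print(toy_itos[:4])\n",
    "print(numericalize(['the', 'ending', 'was', 'great'], toy_stoi))\n",
    "```\n",
    "\n",
    "Any word that was not seen when the vocabulary was built, such as `ending` here, falls back to the `<unk>` index."
   ]
  },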
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's do the same with `LABEL`. We see that there are vectors related to `LABEL`, this is expected because `LABEL` provides the definition of a label of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attributes of LABEL :  ['dtype', 'dtypes', 'lower', 'numericalize', 'pad', 'postprocessing', 'preprocess', 'preprocessing', 'process', 'sequential', 'tokenize', 'vocab']\n",
      "Attributes of LABEL.vocab :  ['extend', 'freqs', 'itos', 'stoi', 'vectors']\n",
      "First 5 values LABEL.vocab.itos :  ['neg', 'pos']\n",
      "First 5 key, value pairs of LABEL.vocab.stoi :  {'neg': 0, 'pos': 1}\n",
      "Shape of LABEL.vocab.vectors :  None\n"
     ]
    }
   ],
   "source": [
    "LABEL.build_vocab(train_data)\n",
    "print ('Attributes of LABEL : ', [attr for attr in dir(LABEL) if '_' not in attr])\n",
    "print ('Attributes of LABEL.vocab : ', [attr for attr in dir(LABEL.vocab) if '_' not in attr])\n",
    "print ('First 5 values LABEL.vocab.itos : ', LABEL.vocab.itos) \n",
    "print ('First 5 key, value pairs of LABEL.vocab.stoi : ', {key:value for key,value in list(LABEL.vocab.stoi.items())})\n",
    "print ('Shape of LABEL.vocab.vectors : ', LABEL.vocab.vectors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we must convert our split datasets into iterators, we'll take advantage of **torchtext.data.BucketIterator**! BucketIterator pads every element of a batch to the length of the longest element of the batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits((train_data, valid_data, test_data), \n",
    "                                                                           batch_size=32,\n",
    "                                                                           device=device)"
   ]
  },
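  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the per-batch padding described above, the sketch below pads a toy batch of index lists in plain Python. The helper `pad_batch` and the toy indices are hypothetical, and this is not how torchtext implements it; `pad_idx=1` is assumed because `<pad>` sits at index 1 of `TEXT.vocab.itos`.\n",
    "\n",
    "```python\n",
    "def pad_batch(sequences, pad_idx=1):\n",
    "    # Pad every sequence to the length of the longest one in this batch.\n",
    "    max_len = max(len(seq) for seq in sequences)\n",
    "    return [seq + [pad_idx] * (max_len - len(seq)) for seq in sequences]\n",
    "\n",
    "toy_batch = [[9, 1107, 205], [24, 4], [454, 12, 49, 1271]]\n",
    "for row in pad_batch(toy_batch):\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "Because each batch is only padded to its own maximum, grouping examples of similar length keeps the number of pad tokens low."
   ]
  },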
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's actually explore what the output of the iterator is, this way we'll know what the input of the model is, how to compare the label to the output and how to setup are process_functions for Ignite's `Engine`.\n",
    "* `batch.label[0]` is the label of a single example. We can see that `LABEL.vocab.stoi` was used to map the label that originally text into a float.\n",
    "* `batch.text[0]` is the text of a single example. Similar to label, `TEXT.vocab.stoi` was used to convert each token of the example's text into indices.\n",
    "\n",
    "Now let's print the lengths of the sentences of the first 10 batches of `train_iterator`. We see here that all the batches are of different lengths, this means that the bucket iterator is doing exactly as we hoped!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch.label[0] :  tensor(1., device='cuda:0')\n",
      "batch.text[0] :  tensor([     9,   1107,    205,     10,     24,      4,    454,     12,     49,\n",
      "          1271,     12,    346,      3,   2793,      6,   3109,     20,    240,\n",
      "             4,     79,     76,    210,      7,     61,    653,     50,     25,\n",
      "            26,      3,    469,    502,    264,    313,      3,    457,  21725,\n",
      "             6,     73,   1090,     92,     89,      6,    209,     67,     12,\n",
      "             4,    181,      2,     20,  32092,   4102, 173802,    260,     31,\n",
      "         16607,      4,     21,   2546,  33723,     69,    748,  44230,     90,\n",
      "            32,    337,   2438,   1509,   4902, 170285,  21837,      3,  47953,\n",
      "           342,    260,     31,   1673,  10692,     14,      3,    448,   1018,\n",
      "          2410,      8,      2,     20,      4,     17,      2,    119,     57,\n",
      "           474,    215,   1779,     14,  12440,   3661,    260,     31,      2,\n",
      "          2356,      4,   2333,   8324,  24453,  23239,     16,     41,    208,\n",
      "           300,    269,    303,      8,      2,     20,    701,      2,    119,\n",
      "            46,     25,    354,    562,     50,    145,      3,    290,     96,\n",
      "            10,     30,      7,     17,    929,    469,    502,    819,   2614,\n",
      "           502,     16,    311,    568,   8498], device='cuda:0')\n",
      "Lengths of first 10 batches :  [978, 625, 650, 450, 587, 671, 526, 984, 789, 911, 669]\n"
     ]
    }
   ],
   "source": [
    "batch = next(iter(train_iterator))\n",
    "print('batch.label[0] : ', batch.label[0])\n",
    "print('batch.text[0] : ', batch.text[0][batch.text[0] != 1])\n",
    "\n",
    "lengths = []\n",
    "for i, batch in enumerate(train_iterator):\n",
    "    x = batch.text\n",
    "    lengths.append(x.shape[1])\n",
    "    if i == 10:\n",
    "        break\n",
    "\n",
    "print ('Lengths of first 10 batches : ', lengths)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TextCNN Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is the replication of the model, here are the operations of the model:\n",
    "* **Embedding**: Embeds a batch of text of shape (N, L) to (N, L, D), where N is batch size, L is maximum length of the batch, D is the embedding dimension. \n",
    "\n",
    "* **Convolutions**: Runs parallel convolutions across the embedded words with kernel sizes of 3, 4, 5 to mimic trigrams, four-grams, five-grams. This results in outputs of (N, L - k + 1, D) per convolution, where k is the kernel_size. \n",
    "\n",
    "* **Activation**: ReLu activation is applied to each convolution operation.\n",
    "\n",
    "* **Pooling**: Runs parallel maxpooling operations on the activated convolutions with window sizes of L - k + 1, resulting in 1 value per channel i.e. a shape of (N, 1, D) per pooling. \n",
    "\n",
    "* **Concat**: The pooling outputs are concatenated and squeezed to result in a shape of (N, 3D). This is a single embedding for a sentence.\n",
    "\n",
    "* **Dropout**: Dropout is applied to the embedded sentence. \n",
    "\n",
    "* **Fully Connected**: The dropout output is passed through a fully connected layer of shape (3D, 1) to give a single output for each example in the batch. sigmoid is applied to the output of this layer.\n",
    "\n",
    "* **load_embeddings**: This is a method defined for TextCNN to load embeddings based on user input. There are 3 modes - rand which results in randomly initialized weights, static which results in frozen pretrained weights, nonstatic which results in trainable pretrained weights. \n",
    "\n",
    "\n",
    "Let's note that this model works for variable text lengths! The idea to embed the words of a sentence, use convolutions, maxpooling and concantenation to embed the sentence as a single vector! This single vector is passed through a fully connected layer with sigmoid to output a single value. This value can be interpreted as the probability a sentence is positive (closer to 1) or negative (closer to 0).\n",
    "\n",
    "The minimum length of text expected by the model is the size of the smallest kernel size of the model."
   ]
  },
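  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape bookkeeping can be checked with simple arithmetic. The sketch below assumes a 1D convolution with stride 1 and no padding applied to a channels-first input of shape (N, D, L), as in the model code that follows. The helper `textcnn_shapes` is hypothetical and the sizes (batch of 32, length 50) are only illustrative.\n",
    "\n",
    "```python\n",
    "def textcnn_shapes(batch_size, seq_len, embedding_dim, kernel_sizes, num_filters):\n",
    "    # Shapes follow the forward pass: embed, transpose, conv, pool, concat.\n",
    "    shapes = {'embedded': (batch_size, seq_len, embedding_dim),\n",
    "              'transposed': (batch_size, embedding_dim, seq_len)}\n",
    "    for k in kernel_sizes:\n",
    "        if seq_len < k:\n",
    "            raise ValueError('sequence is shorter than kernel size %d' % k)\n",
    "        # A stride-1 convolution without padding yields seq_len - k + 1 positions.\n",
    "        shapes['conv_k%d' % k] = (batch_size, num_filters, seq_len - k + 1)\n",
    "        # Max-pooling over all positions leaves one value per filter.\n",
    "        shapes['pool_k%d' % k] = (batch_size, num_filters)\n",
    "    shapes['concat'] = (batch_size, len(kernel_sizes) * num_filters)\n",
    "    return shapes\n",
    "\n",
    "for name, shape in textcnn_shapes(32, 50, 100, [3, 4, 5], 100).items():\n",
    "    print(name, shape)\n",
    "```\n",
    "\n",
    "Whatever the sequence length, the concatenated vector has `len(kernel_sizes) * num_filters` features, which is why the fully connected layer can have a fixed input size."
   ]
  },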
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextCNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, kernel_sizes, num_filters, num_classes, d_prob, mode):\n",
    "        super(TextCNN, self).__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.kernel_sizes = kernel_sizes\n",
    "        self.num_filters = num_filters\n",
    "        self.num_classes = num_classes\n",
    "        self.d_prob = d_prob\n",
    "        self.mode = mode\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=1)\n",
    "        self.load_embeddings()\n",
    "        self.conv = nn.ModuleList([nn.Conv1d(in_channels=embedding_dim,\n",
    "                                             out_channels=num_filters,\n",
    "                                             kernel_size=k, stride=1) for k in kernel_sizes])\n",
    "        self.dropout = nn.Dropout(d_prob)\n",
    "        self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size, sequence_length = x.shape\n",
    "        x = self.embedding(x).transpose(1, 2)\n",
    "        x = [F.relu(conv(x)) for conv in self.conv]\n",
    "        x = [F.max_pool1d(c, c.size(-1)).squeeze(dim=-1) for c in x]\n",
    "        x = torch.cat(x, dim=1)\n",
    "        x = self.fc(self.dropout(x))\n",
    "        return torch.sigmoid(x).squeeze()\n",
    "\n",
    "    def load_embeddings(self):\n",
    "        if 'static' in self.mode:\n",
    "            self.embedding.weight.data.copy_(TEXT.vocab.vectors)\n",
    "            if 'non' not in self.mode:\n",
    "                self.embedding.weight.data.requires_grad = False\n",
    "                print('Loaded pretrained embeddings, weights are not trainable.')\n",
    "            else:\n",
    "                self.embedding.weight.data.requires_grad = True\n",
    "                print('Loaded pretrained embeddings, weights are trainable.')\n",
    "        elif self.mode == 'rand':\n",
    "            print('Randomly initialized embeddings are used.')\n",
    "        else:\n",
    "            raise ValueError('Unexpected value of mode. Please choose from static, nonstatic, rand.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Model, Optimizer and Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we create an instance of the TextCNN model and load embeddings in **static** mode. The model is placed on a device and then a loss function of Binary Cross Entropy and Adam optimizer are setup. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained embeddings, weights are not trainable.\n"
     ]
    }
   ],
   "source": [
    "vocab_size, embedding_dim = TEXT.vocab.vectors.shape\n",
    "\n",
    "model = TextCNN(vocab_size=vocab_size,\n",
    "                embedding_dim=embedding_dim,\n",
    "                kernel_sizes=[3, 4, 5],\n",
    "                num_filters=100,\n",
    "                num_classes=1, \n",
    "                d_prob=0.5,\n",
    "                mode='static')\n",
    "model.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)\n",
    "criterion = nn.BCELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and Evaluating using Ignite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trainer Engine - process_function\n",
    "\n",
    "Ignite's Engine allows user to define a process_function to process a given batch, this is applied to all the batches of the dataset. This is a general class that can be applied to train and validate models! A process_function has two parameters engine and batch. \n",
    "\n",
    "\n",
    "Let's walk through what the function of the trainer does:\n",
    "\n",
    "* Sets model in train mode. \n",
    "* Sets the gradients of the optimizer to zero.\n",
    "* Generate x and y from batch.\n",
    "* Performs a forward pass to calculate y_pred using model and x.\n",
    "* Calculates loss using y_pred and y.\n",
    "* Performs a backward pass using loss to calculate gradients for the model parameters.\n",
    "* model parameters are optimized using gradients and optimizer.\n",
    "* Returns scalar loss. \n",
    "\n",
    "Below is a single operation during the trainig process. This process_function will be attached to the training engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_function(engine, batch):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    x, y = batch.text, batch.label\n",
    "    y_pred = model(x)\n",
    "    loss = criterion(y_pred, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluator Engine - process_function\n",
    "\n",
    "Similar to the training process function, we setup a function to evaluate a single batch. Here is what the eval_function does:\n",
    "\n",
    "* Sets model in eval mode.\n",
    "* Generates x and y from batch.\n",
    "* With torch.no_grad(), no gradients are calculated for any succeding steps.\n",
    "* Performs a forward pass on the model to calculate y_pred based on model and x.\n",
    "* Returns y_pred and y.\n",
    "\n",
    "Ignite suggests attaching metrics to evaluators and not trainers because during the training the model parameters are constantly changing and it is best to evaluate model on a stationary model. This information is important as there is a difference in the functions for training and evaluating. Training returns a single scalar loss. Evaluating returns y_pred and y as that output is used to calculate metrics per batch for the entire dataset.\n",
    "\n",
    "All metrics in Ignite require y_pred and y as outputs of the function attached to the Engine. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_function(engine, batch):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        x, y = batch.text, batch.label\n",
    "        y_pred = model(x)\n",
    "        return y_pred, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiating Training and Evaluating Engines\n",
    "\n",
    "Below we create 3 engines, a trainer, a training evaluator and a validation evaluator. You'll notice that train_evaluator and validation_evaluator use the same function, we'll see later why this was done! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Engine(process_function)\n",
    "train_evaluator = Engine(eval_function)\n",
    "validation_evaluator = Engine(eval_function)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Metrics - RunningAverage, Accuracy and Loss\n",
    "\n",
    "To start, we'll attach a metric of Running Average to track a running average of the scalar loss output for each batch. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now there are two metrics that we want to use for evaluation - accuracy and loss. This is a binary problem, so for Loss we can simply pass the Binary Cross Entropy function as the loss_function. \n",
    "\n",
    "For Accuracy, Ignite requires y_pred and y to be comprised of 0's and 1's only. Since our model outputs from a sigmoid layer, values are between 0 and 1. We'll need to write a function that transforms `engine.state.output` which is comprised of y_pred and y. \n",
    "\n",
    "Below `thresholded_output_transform` does just that, it rounds y_pred to convert y_pred to 0's and 1's, and then returns rounded y_pred and y. This function is the output_transform function used to transform the `engine.state.output` to achieve `Accuracy`'s desired purpose.\n",
    "\n",
    "Now, we attach `Loss` and `Accuracy` (with `thresholded_output_transform`) to train_evaluator and validation_evaluator. \n",
    "\n",
    "To attach a metric to engine, the following format is used:\n",
    "* `Metric(output_transform=output_transform, ...).attach(engine, 'metric_name')`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def thresholded_output_transform(output):\n",
    "    y_pred, y = output\n",
    "    y_pred = torch.round(y_pred)\n",
    "    return y_pred, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "Accuracy(output_transform=thresholded_output_transform).attach(train_evaluator, 'accuracy')\n",
    "Loss(criterion).attach(train_evaluator, 'bce')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "Accuracy(output_transform=thresholded_output_transform).attach(validation_evaluator, 'accuracy')\n",
    "Loss(criterion).attach(validation_evaluator, 'bce')"
   ]
  },
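  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two evaluation metrics concrete, here is a hand computation in plain Python on a toy batch. It is a sketch under stated assumptions rather than Ignite's implementation: the helpers `binary_cross_entropy` and `thresholded_accuracy` and the toy values are hypothetical, and rounding is approximated by thresholding the probabilities at 0.5.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def binary_cross_entropy(probs, targets):\n",
    "    # Mean of -[y * log(p) + (1 - y) * log(1 - p)] over the batch.\n",
    "    losses = [-(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
    "              for p, y in zip(probs, targets)]\n",
    "    return sum(losses) / len(losses)\n",
    "\n",
    "def thresholded_accuracy(probs, targets):\n",
    "    # Threshold at 0.5, then count matches with the labels.\n",
    "    preds = [1.0 if p > 0.5 else 0.0 for p in probs]\n",
    "    correct = sum(1 for pred, y in zip(preds, targets) if pred == y)\n",
    "    return correct / len(targets)\n",
    "\n",
    "toy_probs = [0.9, 0.2, 0.6, 0.4]\n",
    "toy_labels = [1.0, 0.0, 0.0, 1.0]\n",
    "print('accuracy:', thresholded_accuracy(toy_probs, toy_labels))\n",
    "print('bce:', round(binary_cross_entropy(toy_probs, toy_labels), 4))\n",
    "```\n",
    "\n",
    "The loss uses the raw probabilities while the accuracy only sees the thresholded predictions, which is why only `Accuracy` needs the output transform."
   ]
  },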
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Progress Bar\n",
    "\n",
    "Next we create an instance of Ignite's progess bar and attach it to the trainer and pass it a key of `engine.state.metrics` to track. In this case, the progress bar will be tracking `engine.state.metrics['loss']`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
     ]
    }
   ],
   "source": [
    "pbar = ProgressBar(persist=True, bar_format=\"\")\n",
    "pbar.attach(trainer, ['loss'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### EarlyStopping - Tracking Validation Loss\n",
    "\n",
    "Now we'll setup a Early Stopping handler for this training process. EarlyStopping requires a score_function that allows the user to define whatever criteria to stop trainig. In this case, if the loss of the validation set does not decrease in 5 epochs, the training process will stop early.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_function(engine):\n",
    "    val_loss = engine.state.metrics['bce']\n",
    "    return -val_loss\n",
    "\n",
    "handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)\n",
    "validation_evaluator.add_event_handler(Events.COMPLETED, handler)"
   ]
  },
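  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The patience logic can be sketched in a few lines of plain Python. This is a simplified illustration, not Ignite's `EarlyStopping` code: the helper `epochs_until_stop` and the toy losses are hypothetical, and the sketch assumes one validation run per epoch and counts any epoch without a strictly better score against the patience.\n",
    "\n",
    "```python\n",
    "def epochs_until_stop(val_losses, patience):\n",
    "    # Score is the negated loss, so a higher score means a better model.\n",
    "    best_score, counter = None, 0\n",
    "    for epoch, loss in enumerate(val_losses, start=1):\n",
    "        score = -loss\n",
    "        if best_score is None or score > best_score:\n",
    "            best_score, counter = score, 0   # improvement: reset the counter\n",
    "        else:\n",
    "            counter += 1                     # no improvement this epoch\n",
    "            if counter >= patience:\n",
    "                return epoch                 # training would stop here\n",
    "    return None                              # patience was never exhausted\n",
    "\n",
    "toy_losses = [0.64, 0.56, 0.57, 0.58, 0.56, 0.59]\n",
    "print(epochs_until_stop(toy_losses, patience=3))\n",
    "```\n",
    "\n",
    "With these toy losses the best value is reached in epoch 2, and the following three epochs do not beat it, so the sketch reports epoch 5 as the stopping point."
   ]
  },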
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attaching Custom Functions to Engine at specific Events\n",
    "\n",
    "Below you'll see ways to define your own custom functions and attaching them to various `Events` of the training process.\n",
    "\n",
    "The functions below both achieve similar tasks, they print the results of the evaluator run on a dataset. One function does that on the training evaluator and dataset, while the other on the validation. Another difference is how these functions are attached in the trainer engine.\n",
    "\n",
    "The first method involves using a decorator, the syntax is simple - `@` `trainer.on(Events.EPOCH_COMPLETED)`, means that the decorated function will be attached to the trainer and called at the end of each epoch. \n",
    "\n",
    "The second method involves using the add_event_handler method of trainer - `trainer.add_event_handler(Events.EPOCH_COMPLETED, custom_function)`. This achieves the same result as the above. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "@trainer.on(Events.EPOCH_COMPLETED)\n",
    "def log_training_results(engine):\n",
    "    train_evaluator.run(train_iterator)\n",
    "    metrics = train_evaluator.state.metrics\n",
    "    avg_accuracy = metrics['accuracy']\n",
    "    avg_bce = metrics['bce']\n",
    "    pbar.log_message(\n",
    "        \"Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}\"\n",
    "        .format(engine.state.epoch, avg_accuracy, avg_bce))\n",
    "    \n",
    "def log_validation_results(engine):\n",
    "    validation_evaluator.run(valid_iterator)\n",
    "    metrics = validation_evaluator.state.metrics\n",
    "    avg_accuracy = metrics['accuracy']\n",
    "    avg_bce = metrics['bce']\n",
    "    pbar.log_message(\n",
    "        \"Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}\"\n",
    "        .format(engine.state.epoch, avg_accuracy, avg_bce))\n",
    "    pbar.n = pbar.last_print_n = 0\n",
    "\n",
    "trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ModelCheckpoint\n",
    "\n",
    "Lastly, we want to checkpoint this model. It's important to do so, as training processes can be time consuming and if for some reason something goes wrong during training, a model checkpoint can be helpful to restart training from the point of failure.\n",
    "\n",
    "Below we'll use Ignite's `ModelCheckpoint` handler to checkpoint models at the end of each epoch. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpointer = ModelCheckpoint('/tmp/models', 'textcnn', save_interval=1, n_saved=2, create_dir=True, save_as_state_dict=True)\n",
    "trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'textcnn': model})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Engine\n",
    "\n",
    "Next, we'll run the trainer for 10 epochs and monitor results. Below we can see that progess bar prints the loss per iteration, and prints the results of training and validation as we specified in our custom function. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fff122f9c62148adac293b981b0c0af6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 1  Avg accuracy: 0.76 Avg loss: 0.64\n",
      "Validation Results - Epoch: 1  Avg accuracy: 0.73 Avg loss: 0.64\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff8745963cce4e799d4e804e427c39b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 2  Avg accuracy: 0.79 Avg loss: 0.56\n",
      "Validation Results - Epoch: 2  Avg accuracy: 0.78 Avg loss: 0.56\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1d758e6ebb44bbaaf63369e0a723040",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 3  Avg accuracy: 0.80 Avg loss: 0.48\n",
      "Validation Results - Epoch: 3  Avg accuracy: 0.79 Avg loss: 0.49\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31b248535085452a9ece3521c0f19704",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 4  Avg accuracy: 0.81 Avg loss: 0.44\n",
      "Validation Results - Epoch: 4  Avg accuracy: 0.80 Avg loss: 0.45\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7370678ac46f475998dc2f874eedda7d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 5  Avg accuracy: 0.82 Avg loss: 0.42\n",
      "Validation Results - Epoch: 5  Avg accuracy: 0.81 Avg loss: 0.43\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c533b8a1c594e028be9422ed55a8bbc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 6  Avg accuracy: 0.83 Avg loss: 0.40\n",
      "Validation Results - Epoch: 6  Avg accuracy: 0.81 Avg loss: 0.42\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1597ea3c65034cc4bbbcb8c013adb96c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 7  Avg accuracy: 0.84 Avg loss: 0.39\n",
      "Validation Results - Epoch: 7  Avg accuracy: 0.82 Avg loss: 0.41\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0bad52713dd496fab612d43cbc8f0ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 8  Avg accuracy: 0.85 Avg loss: 0.37\n",
      "Validation Results - Epoch: 8  Avg accuracy: 0.82 Avg loss: 0.40\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed22e1ce816a4aa0b885cc80702860a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 9  Avg accuracy: 0.85 Avg loss: 0.36\n",
      "Validation Results - Epoch: 9  Avg accuracy: 0.83 Avg loss: 0.39\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31b84c5326e94efc9fa6b6ac8a176fd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 10  Avg accuracy: 0.85 Avg loss: 0.35\n",
      "Validation Results - Epoch: 10  Avg accuracy: 0.83 Avg loss: 0.38\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a231ba4c28a243db868f676c784411c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 11  Avg accuracy: 0.86 Avg loss: 0.34\n",
      "Validation Results - Epoch: 11  Avg accuracy: 0.83 Avg loss: 0.37\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5650adc8034e4d3686da74ae4a21aaf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 12  Avg accuracy: 0.87 Avg loss: 0.33\n",
      "Validation Results - Epoch: 12  Avg accuracy: 0.84 Avg loss: 0.37\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6e6108d2fa54175b0fcd57e4af13362",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 13  Avg accuracy: 0.87 Avg loss: 0.32\n",
      "Validation Results - Epoch: 13  Avg accuracy: 0.84 Avg loss: 0.36\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7e2828d9042a4d8a9ec0e189334c7800",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 14  Avg accuracy: 0.87 Avg loss: 0.32\n",
      "Validation Results - Epoch: 14  Avg accuracy: 0.84 Avg loss: 0.36\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "379d84a46a1f421990e63821074da8c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 15  Avg accuracy: 0.88 Avg loss: 0.31\n",
      "Validation Results - Epoch: 15  Avg accuracy: 0.84 Avg loss: 0.35\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f5e8535883024e9aa3875e77cd3e294a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 16  Avg accuracy: 0.88 Avg loss: 0.30\n",
      "Validation Results - Epoch: 16  Avg accuracy: 0.85 Avg loss: 0.34\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c4372380e884c30b4fd25cbfa146d0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 17  Avg accuracy: 0.89 Avg loss: 0.29\n",
      "Validation Results - Epoch: 17  Avg accuracy: 0.85 Avg loss: 0.34\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e1d44194fc7496593d23b24019edc45",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 18  Avg accuracy: 0.89 Avg loss: 0.28\n",
      "Validation Results - Epoch: 18  Avg accuracy: 0.85 Avg loss: 0.34\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "372de51144b746c2a49948d9e2dff360",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 19  Avg accuracy: 0.90 Avg loss: 0.27\n",
      "Validation Results - Epoch: 19  Avg accuracy: 0.85 Avg loss: 0.33\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "949c2fe98ebc44f9a943697c585f9b83",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=625), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training Results - Epoch: 20  Avg accuracy: 0.90 Avg loss: 0.26\n",
      "Validation Results - Epoch: 20  Avg accuracy: 0.86 Avg loss: 0.33\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<ignite.engine.engine.State at 0x7f40aad2c400>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.run(train_iterator, max_epochs=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it! We have successfully trained and evaluated a Convolutational Neural Network for Text Classification. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
